In [ ]:
# Please ignore these variables; they only provide options for our CI system.
args = []
abort_after_one = False

Tutorial: Federated learning with websockets and federated averaging, with possible solutions for problems you might face

This notebook walks through each step in detail and discusses problems you might face along the way.

Make sure you have the correct websocket-client library installed. The websocket-client package is imported with import websocket, so if another package named websocket is installed in the same environment, import websocket may load that package instead. Calling websocket.create_connection() then fails with an error saying that the websocket module has no attribute create_connection.

Solution: in a terminal, activate the environment where syft is installed, run pip uninstall websocket to remove the conflicting package, then run pip install --upgrade websocket_client.
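To check which package import websocket actually resolves to, a small diagnostic like the one below can help. It is a minimal sketch: the function name check_websocket_module is our own, and it only relies on the fact stated above that websocket-client provides create_connection.

```python
import importlib


def check_websocket_module(name="websocket"):
    """Report whether `import <name>` resolves to a module with create_connection.

    websocket-client exposes create_connection; if another package shadows the
    name, the attribute is missing and a diagnostic message is returned instead.
    """
    try:
        module = importlib.import_module(name)
    except ImportError:
        return f"'{name}' cannot be imported; run: pip install --upgrade websocket_client"
    if hasattr(module, "create_connection"):
        return "ok"
    # Show where the shadowing module was loaded from, to help locate it.
    location = getattr(module, "__file__", "<unknown location>")
    return (f"'{name}' resolves to {location}, which has no create_connection; "
            "run: pip uninstall websocket, then pip install --upgrade websocket_client")


if __name__ == "__main__":
    print(check_websocket_module())
```

If this prints anything other than ok, apply the uninstall/reinstall steps above and restart the kernel.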

Authors:

  • midokura-silvia

Preparation: start the websocket server workers

Each worker is represented by two parts, a local handle (websocket client worker) and the remote instance that holds the data and performs the computations. The remote part is called a websocket server worker.

First, cd into the folder that contains this notebook and the additional files for running the server and client workers. For example, on Windows 10:

cd (path to your projects directory) \python_projects\websockets-example-MNIST

Note: Don't copy and paste the path above; it is only an example, and your path will differ depending on your OS and project folder.

This matters because start_websocket_servers.py launches Python subprocesses that run the other scripts which start the websocket server workers, and it refers to those scripts by file name only (their full path may vary). If you run it from a different directory, those scripts will not be found. Now we need to create the remote workers. For this, run the following in a terminal (this is not possible from the notebook):

python start_websocket_servers.py
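One way to make such a launcher independent of the current working directory is to resolve the server script relative to the launcher's own folder. The sketch below illustrates the idea; the script name run_websocket_server.py and its --port/--id flags are assumptions, not a confirmed description of the files shipped with this example. Only the worker ids and ports are taken from this notebook.

```python
import subprocess
import sys
from pathlib import Path

# Worker ids and ports as used later in this notebook.
WORKERS = [("alice", 8777), ("bob", 8778), ("charlie", 8779)]


def build_server_commands(script_dir, workers=WORKERS,
                          server_script="run_websocket_server.py"):
    """Build one command per worker (script name and flags are assumptions).

    The script path is resolved against script_dir instead of the current
    working directory, so the launcher works wherever it is started from.
    """
    script = Path(script_dir).resolve() / server_script
    return [[sys.executable, str(script), "--port", str(port), "--id", worker_id]
            for worker_id, port in workers]


def start_servers(commands):
    """Start each server in its own subprocess and return the Popen handles."""
    return [subprocess.Popen(cmd) for cmd in commands]
```

Inside a launcher script you would call start_servers(build_server_commands(Path(__file__).parent)); using sys.executable also ensures the servers run with the same Python environment as the launcher.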

Setting up the websocket client workers

We first need to perform the imports and set up some arguments and variables.


In [ ]:
%load_ext autoreload
%autoreload 2

In [ ]:
import sys
import syft as sy
from syft.workers.websocket_client import WebsocketClientWorker
import torch
from torchvision import datasets, transforms

from syft.frameworks.torch.fl import *

In [ ]:
import run_websocket_client as rwc

In [ ]:
args = rwc.define_and_get_arguments(args=args)
use_cuda = args.cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
print(args)
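The args object above comes from rwc.define_and_get_arguments, whose source is not shown in this notebook. As a rough sketch of what such a helper can look like with argparse, here is a hypothetical stand-in. Only the attribute names are taken from the cells in this notebook; every default value is an illustrative assumption, so check print(args) for the real ones.

```python
import argparse


def define_arguments_sketch(args=None):
    """Hypothetical stand-in for rwc.define_and_get_arguments.

    Attribute names match those used in this notebook; defaults are assumptions.
    """
    parser = argparse.ArgumentParser(description="Federated MNIST over websockets (sketch)")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--test_batch_size", type=int, default=1000)
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--federate_after_n_batches", type=int, default=50)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--cuda", action="store_true", help="use CUDA if available")
    parser.add_argument("--verbose", "-v", action="store_true", help="verbose worker output")
    # Passing an explicit list avoids parsing the Jupyter kernel's own sys.argv.
    return parser.parse_args(args=args)
```

In a notebook, pass an explicit list (for example args=[]), as the cell above does; with args=None, argparse would try to parse the kernel's command-line arguments and fail.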

Now let's instantiate the websocket client workers, our local access points to the remote workers. Note that this step will fail if the websocket server workers are not running.


In [ ]:
hook = sy.TorchHook(torch)

kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": args.verbose}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

workers = [alice, bob, charlie]
print(workers)
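Because instantiating the client workers fails when the servers are not running, it can be convenient to check the ports first. This is a minimal sketch using only the standard library; the helper name unreachable_ports is our own, and a successful TCP connection only shows that something is listening on the port, not that it is a syft worker.

```python
import socket


def unreachable_ports(host, ports, timeout=1.0):
    """Return the ports on `host` that refuse or time out a TCP connection."""
    missing = []
    for port in ports:
        try:
            with socket.create_connection((host, port), timeout=timeout):
                pass  # connection accepted: something is listening here
        except OSError:
            missing.append(port)
    return missing
```

For example, unreachable_ports("localhost", [8777, 8778, 8779]) returning a non-empty list means python start_websocket_servers.py has not been started (or has exited) for those workers.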

Prepare and distribute the training data

We will use the MNIST dataset and distribute the data randomly onto the workers. This is not realistic for a federated training setup, where the data would normally already be available at the remote workers.
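To make "distribute the data randomly" concrete, here is a minimal sketch of a random, near-equal split of sample indices across workers. It is an illustration under our own assumptions (a fixed seed and round-robin slicing of a shuffled index list), not how syft's federate() is implemented.

```python
import random


def partition_indices(n_samples, worker_ids, seed=0):
    """Randomly assign sample indices to workers in near-equal shares.

    Illustrative only; this is not syft's federate() implementation.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    k = len(worker_ids)
    # Round-robin slicing: worker i gets every k-th index starting at i.
    return {wid: indices[i::k] for i, wid in enumerate(worker_ids)}
```

For the 60,000 MNIST training images and the three workers alice, bob and charlie, each worker would receive 20,000 indices.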

We instantiate two data loaders: a FederatedDataLoader for the MNIST training set, which is distributed across the workers, and a regular PyTorch DataLoader for the MNIST test set, which stays local.

If you run into BrokenPipe errors, go to the parent directory of your project directory, delete the data folder, restart the notebook and try again. If the error persists, delete the data folder again and run the next cell, which only downloads the MNIST training data, before running the cell that creates the data loaders.

For example, the data folder lives in

(path to your projects directory) \python_projects\

while the project notebook and scripts live in

(path to your projects directory) \python_projects\websockets-example-MNIST

As before, these paths are only examples; yours will differ depending on your OS and project folder.
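If you prefer to delete the cached dataset from Python instead of a file manager, a small helper like the following works. It is a sketch with a hypothetical name, reset_data_dir; it permanently deletes the folder you pass, so point it only at your dataset folder (here "../data", the path used by the cells in this notebook).

```python
import shutil
from pathlib import Path


def reset_data_dir(data_dir="../data"):
    """Delete a cached dataset folder so it is downloaded again.

    Returns True if a folder was removed, False if nothing was there.
    """
    path = Path(data_dir)
    if path.is_dir():
        shutil.rmtree(path)  # removes the folder and everything inside it
        return True
    return False
```

After calling it, restart the notebook so the next download starts from a clean state.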


In [ ]:
# Run this cell only if the next cell raises a BrokenPipe error.
# It only downloads the raw MNIST files into ../data, without any transforms.
datasets.MNIST("../data", train=True, download=True)

In [ ]:
federated_train_loader = sy.FederatedDataLoader(
    datasets.MNIST(
        "../data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ).federate(tuple(workers)),
    batch_size=args.batch_size,
    shuffle=True,
    iter_per_worker=True
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=args.test_batch_size,
    shuffle=True
)

Next, we need to instantiate the machine learning model. It is a small neural network with two convolutional and two fully connected layers, using ReLU activations and max pooling.


In [ ]:
model = rwc.Net().to(device)
print(model)
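When reading the printed model, the input size of the first fully connected layer follows from the convolution and pooling arithmetic. The sketch below assumes 28x28 MNIST inputs, two 5x5 convolutions with stride 1 and no padding, 20 and 50 output channels, and 2x2 max pooling with stride 2; these layer sizes are assumptions, so compare them with the output of print(model).

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a 2D convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(input_size=28, channels=(20, 50), kernel=5, pool=2):
    """Size of the flattened feature vector after two conv + max-pool stages."""
    size = input_size
    for _ in channels:
        size = conv2d_out(size, kernel)             # convolution, stride 1, no padding
        size = conv2d_out(size, pool, stride=pool)  # 2x2 max pooling, stride 2
    return size * size * channels[-1]
```

Under these assumptions the spatial size shrinks 28 → 24 → 12 → 8 → 4, so the first fully connected layer would take 4 * 4 * 50 = 800 inputs.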

In [ ]:
import logging
import sys
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter("%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s")
handler.setFormatter(formatter)
logger.handlers = [handler]

Let's start the training

Now we are ready to start the federated training. On each worker, we train the model separately for a given number of batches (args.federate_after_n_batches), then compute the federated average of the resulting models and evaluate the test accuracy of the averaged model.


In [ ]:
for epoch in range(1, args.epochs + 1):
    print("Starting epoch {}/{}".format(epoch, args.epochs))
    model = rwc.train(model, device, federated_train_loader, args.lr, args.federate_after_n_batches, 
                      abort_after_one=abort_after_one)
    rwc.test(model, device, test_loader)
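The averaging itself happens inside rwc.train, whose source is not shown here. As a minimal sketch of the idea, assuming plain Python dicts of float lists stand in for PyTorch state_dicts, an unweighted federated average takes the element-wise mean of each parameter across workers:

```python
def federated_average(worker_params):
    """Unweighted federated average of per-worker parameters.

    worker_params: list of dicts mapping parameter name -> list of floats,
    standing in for model state_dicts. All dicts must share keys and lengths.
    """
    if not worker_params:
        raise ValueError("need at least one worker's parameters")
    n = len(worker_params)
    averaged = {}
    for name in worker_params[0]:
        # Group the i-th value of this parameter from every worker, then average.
        columns = zip(*(params[name] for params in worker_params))
        averaged[name] = [sum(values) / n for values in columns]
    return averaged
```

The original federated averaging algorithm weights each worker's model by its number of local training samples; when every worker holds the same number of samples, the weighted and unweighted averages coincide.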

Congratulations!!! - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

Star PySyft on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at http://slack.openmined.org

Join a Code Project!

The best way to contribute to our community is to become a code contributor! At any time you can go to PySyft GitHub Issues page and filter for "Projects". This will show you all the top level Tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more "one off" mini-projects by searching for GitHub issues marked "good first issue".

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

OpenMined's Open Collective Page

